import matplotlib.pyplot as plt
# Plot the per-epoch training loss curve in a fresh figure.
# NOTE(review): `epochs` and `train_loss` come from earlier in the file.
# This assumes len(train_loss) == epochs; matplotlib raises ValueError otherwise.
plt.figure()

# Axis labels go on the current (newly created) axes.
plt.xlabel('epochs')
plt.ylabel('loss')

# X positions are the epoch indices 0..epochs-1; matplotlib accepts a range directly.
plt.plot(range(epochs), train_loss, label='train_loss')  # training-set loss

# The legend must be created after the labelled line exists so it picks it up.
plt.legend()
plt.show()
